import types
from typing import List, Optional, TYPE_CHECKING

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError

if TYPE_CHECKING:
    from litellm.types.utils import ModelResponse


class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
    - `top_k` (float) top k for model

    Values passed to ``__init__`` are stored as *class* attributes, so they are
    shared by every instance of this class and are what ``get_config`` reports.
    ``top_k`` is not mapped from OpenAI-style params by ``map_openai_params``.
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        """Record every non-None argument as a class-level default.

        Note: values are written to ``self.__class__`` rather than the instance,
        so they persist across instances and are returned by ``get_config``.
        """
        # Snapshot the arguments; must run before any other local is bound so
        # only the parameters (and ``self``) are captured.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return this class's own configured attributes as a dict.

        Reads ``cls.__dict__`` and skips names starting with ``__`` or
        ``_abc``, functions, builtins, classmethods, staticmethods, and any
        attribute whose value is ``None``.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI-style parameter names this config accepts."""
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported OpenAI-style params into ``optional_params``.

        Each supported key present in ``non_default_params`` is copied under
        the same name; any other key is ignored. ``model`` and ``drop_params``
        are not used by this implementation.

        Args:
            non_default_params: Params explicitly supplied by the caller.
            optional_params: Destination dict; mutated in place.
            model: Model name (unused here).
            drop_params: Not consulted; unsupported keys are always skipped.

        Returns:
            The same ``optional_params`` dict that was passed in.
        """
        # Kept local (not a class attribute) so get_config() does not report it.
        passthrough = ("max_tokens", "temperature", "top_p", "stop", "stream")
        for param, value in non_default_params.items():
            if param in passthrough:
                optional_params[param] = value
        return optional_params

    @staticmethod
    def get_outputText(completion_response: dict, model_response: "ModelResponse") -> str:
        """Extract the output text from a Bedrock Mistral completion.

        As a side effect, sets ``model_response.choices[0].finish_reason``.

        Two response shapes are handled; only the first entry is read:

        - ``{"choices": [{"message": {"content": ...}, "finish_reason": ...}]}``
        - ``{"outputs": [{"text": ..., "stop_reason": ...}]}``

        Args:
            completion_response: JSON from the completion.
            model_response: ModelResponse whose first choice's finish reason is updated.

        Returns:
            A string with the response of the LLM.

        Raises:
            BedrockError: (status 400) when neither ``choices`` nor ``outputs``
                is present, or when the present list is empty/null.
        """
        if "choices" in completion_response:
            choices = completion_response["choices"]
            # Guard so an empty list surfaces as a BedrockError, not IndexError.
            if not choices:
                raise BedrockError(
                    message="Unexpected mistral completion response: empty 'choices'",
                    status_code=400,
                )
            first_choice = choices[0]
            output_text = first_choice["message"]["content"]
            model_response.choices[0].finish_reason = first_choice["finish_reason"]
        elif "outputs" in completion_response:
            outputs = completion_response["outputs"]
            if not outputs:
                raise BedrockError(
                    message="Unexpected mistral completion response: empty 'outputs'",
                    status_code=400,
                )
            first_output = outputs[0]
            output_text = first_output["text"]
            model_response.choices[0].finish_reason = first_output["stop_reason"]
        else:
            raise BedrockError(message="Unexpected mistral completion response", status_code=400)

        return output_text
